{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Overview: \n",
        "This notebook implements one approach to multi-modal RAG. It extracts text and images from a PDF, summarizes the images with a multi-modal LLM, and indexes both the page text and the image summaries in a Retrieval-Augmented Generation (RAG) pipeline for question answering.\n",
        "\n",
        "### Key Components:\n",
        "   - **PyMuPDF**: For extracting text and images from PDFs.\n",
        "   - **Gemini 1.5-flash model**: To summarize images and tables.\n",
        "   - **Cohere Embeddings**: For embedding document splits.\n",
        "   - **Chroma Vectorstore**: To store and retrieve document embeddings.\n",
        "   - **LangChain**: To orchestrate the retrieval and generation pipeline."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Diagram:\n",
        "   <img src=\"../images/multi_model_rag_with_captioning.svg\" alt=\"Multi-modal RAG with captioning\" width=\"300\">\n",
        "\n",
        "### Motivation: \n",
        "Efficiently summarize complex documents to facilitate easy retrieval and concise responses for multi-modal data.\n",
        "\n",
        "### Method Details:\n",
        "   - Text and images are extracted from the PDF using PyMuPDF.\n",
        "   - Summarization is performed on extracted images and tables using Gemini.\n",
        "   - Embeddings are generated via Cohere for storage in Chroma.\n",
        "   - A similarity-based retriever fetches relevant sections based on the query.\n",
        "\n",
        "### Benefits:\n",
        "   - Simplified retrieval from complex, multi-modal documents.\n",
        "   - Streamlined Q&A process for both text and images.\n",
        "   - Flexible architecture for expanding to more document types.\n",
        "\n",
        "### Implementation:\n",
        "   - Documents are split into chunks with overlap using a text splitter.\n",
        "   - Page text chunks and image summaries are stored as vectors.\n",
        "   - Queries are handled by retrieving relevant document segments and generating concise answers.\n",
        "\n",
        "### Summary: \n",
        "The project enables multi-modal document processing and retrieval, providing concise, relevant responses by combining state-of-the-art LLMs and vector-based retrieval systems."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Package Installation and Imports\n",
        "\n",
        "The cell below installs the packages required to run this notebook.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
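        "# Install required packages (including the Gemini SDK, Cohere integration, Chroma, and tiktoken used below)\n",
        "!pip install langchain langchain-community langchain-cohere google-generativeai chromadb tiktoken pillow pymupdf python-dotenv"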
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "execution_count": 1,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import fitz  # PyMuPDF\n",
        "from PIL import Image\n",
        "import io\n",
        "import os\n",
        "from dotenv import load_dotenv\n",
        "\n",
        "import google.generativeai as genai\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from langchain_core.documents import Document\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "from langchain_community.vectorstores import Chroma\n",
        "from langchain_cohere import ChatCohere, CohereEmbeddings\n",
        "\n",
        "load_dotenv()"
      ]
    },
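    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The later cells only work if API credentials are available after `load_dotenv()` has run. As a minimal sketch, the check below reports which keys are missing. `GOOGLE_API_KEY` is the name this notebook reads when configuring Gemini; `COHERE_API_KEY` is assumed here to be the variable that `langchain_cohere` picks up by default, and `missing_keys` is a hypothetical helper introduced only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Hypothetical helper: names of required environment variables that are unset or empty\n",
        "def missing_keys(names):\n",
        "    return [name for name in names if not os.getenv(name)]\n",
        "\n",
        "# COHERE_API_KEY is an assumption about the variable langchain_cohere reads by default\n",
        "missing = missing_keys([\"GOOGLE_API_KEY\", \"COHERE_API_KEY\"])\n",
        "if missing:\n",
        "    print(\"Missing API keys:\", \", \".join(missing))\n",
        "else:\n",
        "    print(\"All expected API keys are set.\")"
      ]
    },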
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Download the \"Attention Is All You Need\" paper"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2024-09-20 19:19:26--  https://arxiv.org/pdf/1706.03762\n",
            "Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.3.42, 151.101.67.42, ...\n",
            "Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2215244 (2.1M) [application/pdf]\n",
            "Saving to: ‘1706.03762’\n",
            "\n",
            "1706.03762          100%[===================>]   2.11M  13.3MB/s    in 0.2s    \n",
            "\n",
            "2024-09-20 19:19:26 (13.3 MB/s) - ‘1706.03762’ saved [2215244/2215244]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "!wget https://arxiv.org/pdf/1706.03762\n",
        "!mv 1706.03762 attention_is_all_you_need.pdf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Data Extraction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "text_data = []\n",
        "img_data = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "with fitz.open('attention_is_all_you_need.pdf') as pdf_file:\n",
        "    # Create a directory to store the images (no error if it already exists)\n",
        "    os.makedirs(\"extracted_images\", exist_ok=True)\n",
        "\n",
        "    # Loop through every page in the PDF\n",
        "    for page_number in range(len(pdf_file)):\n",
        "        page = pdf_file[page_number]\n",
        "        \n",
        "        # Get the text on page\n",
        "        text = page.get_text().strip()\n",
        "        text_data.append({\"response\": text, \"name\": page_number+1})\n",
        "        # Get the list of images on the page\n",
        "        images = page.get_images(full=True)\n",
        "\n",
        "        # Loop through all images found on the page\n",
        "        for image_index, img in enumerate(images, start=0):\n",
        "            xref = img[0]  # Get the XREF of the image\n",
        "            base_image = pdf_file.extract_image(xref)  # Extract the image\n",
        "            image_bytes = base_image[\"image\"]  # Get the image bytes\n",
        "            image_ext = base_image[\"ext\"]  # Get the image extension\n",
        "            \n",
        "            # Load the image using PIL and save it\n",
        "            image = Image.open(io.BytesIO(image_bytes))\n",
        "            image.save(f\"extracted_images/image_{page_number+1}_{image_index+1}.{image_ext}\")"
      ]
    },
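    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before captioning, it can help to confirm what the extraction loop actually produced. The cell below is a small sanity check, not part of the original pipeline. It assumes the extraction cell above has already run, so that `text_data` is populated and the `extracted_images` folder exists. Pages listed as empty are typically pages whose content is mostly figures."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Assumes the extraction cell above has run: text_data is filled and extracted_images exists\n",
        "image_files = sorted(os.listdir(\"extracted_images\"))\n",
        "empty_pages = [page[\"name\"] for page in text_data if not page[\"response\"]]\n",
        "\n",
        "print(f\"Pages processed: {len(text_data)}\")\n",
        "print(f\"Pages with no extractable text: {empty_pages}\")\n",
        "print(f\"Images saved: {len(image_files)}\")"
      ]
    },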
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))\n",
        "model = genai.GenerativeModel(model_name=\"gemini-1.5-flash\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Image Captioning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "# One shared prompt; adjacent string literals avoid embedding stray indentation in the prompt\n",
        "caption_prompt = (\n",
        "    \"You are an assistant tasked with summarizing tables, images and text for retrieval. \"\n",
        "    \"These summaries will be embedded and used to retrieve the raw text or table elements. \"\n",
        "    \"Give a concise summary of the table, text or image that is well optimized for retrieval. \"\n",
        "    \"Table or text or image:\"\n",
        ")\n",
        "for img in sorted(os.listdir(\"extracted_images\")):  # sorted for a reproducible order\n",
        "    image = Image.open(f\"extracted_images/{img}\")\n",
        "    response = model.generate_content([image, caption_prompt])\n",
        "    img_data.append({\"response\": response.text, \"name\": img})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Vectorstore"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Set embeddings\n",
        "embedding_model = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
        "\n",
        "# Wrap the page texts and image summaries as LangChain Documents\n",
        "docs_list = [Document(page_content=text['response'], metadata={\"name\": text['name']}) for text in text_data]\n",
        "img_list = [Document(page_content=img['response'], metadata={\"name\": img['name']}) for img in img_data]\n",
        "\n",
        "# Split\n",
        "text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
        "    chunk_size=400, chunk_overlap=50\n",
        ")\n",
        "\n",
        "doc_splits = text_splitter.split_documents(docs_list)\n",
        "img_splits = text_splitter.split_documents(img_list)"
      ]
    },
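    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what chunking with overlap does, the cell below runs the same splitter class on a short sample string. This is only an illustrative sketch: the sample text and the toy sizes (60 characters per chunk, 20 characters of overlap) are assumptions chosen so the effect is visible, and length is measured in characters rather than tokens, unlike the `from_tiktoken_encoder` setup above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "\n",
        "# Toy values for illustration; the notebook itself uses 400 tokens with 50 tokens of overlap\n",
        "toy_splitter = RecursiveCharacterTextSplitter(chunk_size=60, chunk_overlap=20)\n",
        "\n",
        "sample_text = (\n",
        "    \"The Transformer relies entirely on attention mechanisms. \"\n",
        "    \"It dispenses with recurrence and convolutions. \"\n",
        "    \"This allows for significantly more parallelization during training.\"\n",
        ")\n",
        "\n",
        "# Length is measured in characters here, not tokens, so no tokenizer is needed\n",
        "for i, chunk in enumerate(toy_splitter.split_text(sample_text), start=1):\n",
        "    print(f\"chunk {i} ({len(chunk)} chars): {chunk}\")"
      ]
    },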
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Add to vectorstore\n",
        "vectorstore = Chroma.from_documents(\n",
        "    documents=doc_splits + img_splits, # add both the text and image splits\n",
        "    collection_name=\"multi_model_rag\",\n",
        "    embedding=embedding_model,\n",
        ")\n",
        "\n",
        "retriever = vectorstore.as_retriever(\n",
        "                search_type=\"similarity\",\n",
        "                search_kwargs={'k': 1}, # number of documents to retrieve\n",
        "            )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Query"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "query = \"What is the BLEU score of the Transformer (base model)?\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [],
      "source": [
        "docs = retriever.invoke(query)"
      ]
    },
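    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It is often useful to look at what was retrieved before generating an answer. The cell below is a sketch that assumes `docs`, `query`, and `vectorstore` from the cells above. It prints the source of each retrieved chunk (a page number for text, a file name for an image summary) and then lists the three nearest chunks with their scores. With Chroma the score is expected to be a distance, so lower values should mean a closer match; treat that interpretation as an assumption to verify against your installed version."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assumes docs, query and vectorstore were created by the cells above\n",
        "for doc in docs:\n",
        "    print(\"source:\", doc.metadata.get(\"name\"))\n",
        "    print(doc.page_content[:300])\n",
        "\n",
        "# similarity_search_with_score returns (Document, score) pairs\n",
        "for doc, score in vectorstore.similarity_search_with_score(query, k=3):\n",
        "    print(f\"{score:.4f}  {doc.metadata.get('name')}\")"
      ]
    },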
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The Transformer (base model) achieves a BLEU score of 27.3.\n"
          ]
        }
      ],
      "source": [
        "from langchain_core.output_parsers import StrOutputParser\n",
        "\n",
        "# Prompt\n",
        "system = \"\"\"You are an assistant for question-answering tasks. Answer the question based on the retrieved documents.\n",
        "Use at most five sentences and keep the answer concise.\"\"\"\n",
        "prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"Retrieved documents: \\n\\n <docs>{documents}</docs> \\n\\n User question: <question>{question}</question>\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# LLM\n",
        "llm = ChatCohere(model=\"command-r-plus\", temperature=0)\n",
        "\n",
        "# Chain\n",
        "rag_chain = prompt | llm | StrOutputParser()\n",
        "\n",
        "# Run\n",
        "generation = rag_chain.invoke({\"documents\":docs[0].page_content, \"question\": query})\n",
        "print(generation)"
      ]
    },
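    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The retrieval and generation steps above can be wrapped into one function for asking further questions. This is a minimal sketch that assumes `retriever` and `rag_chain` from the cells above; `answer_question` is a hypothetical helper name and the example question is only an illustration. It joins every retrieved chunk into the context, so it keeps working if `k` is raised above 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper: retrieve context for a question, then generate an answer from it\n",
        "def answer_question(question):\n",
        "    retrieved = retriever.invoke(question)\n",
        "    # Join all retrieved chunks so the helper also works when k > 1\n",
        "    context = \"\\n\\n\".join(doc.page_content for doc in retrieved)\n",
        "    return rag_chain.invoke({\"documents\": context, \"question\": question})\n",
        "\n",
        "print(answer_question(\"How many layers does the Transformer encoder have?\"))"
      ]
    },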
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "![](https://europe-west1-rag-techniques-views-tracker.cloudfunctions.net/rag-techniques-tracker?notebook=all-rag-techniques--multi-model-rag-with-captioning)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}